{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Verify Exported ONNX Model in FINN\n",
    "\n",
    "<font color=\"red\">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for \"latency hiding\".</font>\n",
    "\n",
    "**Important: This notebook depends on the 1-train-mlp-with-brevitas notebook, because we are using the ONNX model that was exported there. So please make sure the needed .onnx file is generated before you run this notebook.**\n",
    "\n",
    "**Also remember to 'close and halt' any other FINN notebooks, since Netron visualizations use the same port.**\n",
    "\n",
    "In this notebook we will show how to import the network we trained in Brevitas and verify it in the FINN compiler. \n",
    "This verification process can be done at various stages in the compiler, [as explained in this notebook](../bnn-pynq/tfc_end2end_verification.ipynb), but for this example we'll only consider the first step: verifying the exported high-level FINN-ONNX model.\n",
    "Another goal of this notebook is to introduce you to the concept of *graph transformations* -- we'll be applying some transformations to the graph to make it executable for verification.\n",
    "Once this model is successfully verified, we'll generate an FPGA accelerator from it in the next notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx \n",
    "import torch "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "-------------\n",
    "1. [Import model into FINN with ModelWrapper](#brevitas_import_visualization)\n",
    "2. [Network preparation: Tidy-up transformations](#network_preparations)\n",
    "3. [Load the Dataset and the Brevitas Model](#load_dataset)\n",
    "4. [Compare FINN & Brevitas execution](#compare_brevitas)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Import model into FINN with ModelWrapper <a id=\"brevitas_import_visualization\"></a>\n",
    "\n",
    "Now that we have the model in .onnx format, we can work with it using FINN. To import it into FINN, we'll use the [`ModelWrapper`](https://finn.readthedocs.io/en/latest/source_code/finn.core.html#qonnx.core.modelwrapper.ModelWrapper). It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from qonnx.core.modelwrapper import ModelWrapper\n",
    "\n",
    "model_dir = os.environ['FINN_ROOT'] + \"/notebooks/end2end_example/cybersecurity\"\n",
    "ready_model_filename = model_dir + \"/cybsec-mlp-ready.onnx\"\n",
    "model_for_sim = ModelWrapper(ready_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a look at some of the member functions exposed by `ModelWrapper` to see what kind of information we can extract from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir(model_for_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Many of these helper functions relate to extracting information about the structure and properties of the ONNX model. You can find out more about examining and manipulating ONNX models programmatically in [this tutorial](../../basics/0_how_to_work_with_onnx.ipynb), but we'll show a few basic functions here. For instance, we can extract the shape and datatype annotation for various tensors in the graph, as well as information related to the operation types associated with each node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qonnx.core.datatype import DataType\n",
    "\n",
    "finnonnx_in_tensor_name = model_for_sim.graph.input[0].name\n",
    "finnonnx_out_tensor_name = model_for_sim.graph.output[0].name\n",
    "print(\"Input tensor name: %s\" % finnonnx_in_tensor_name)\n",
    "print(\"Output tensor name: %s\" % finnonnx_out_tensor_name)\n",
    "finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)\n",
    "finnonnx_model_out_shape = model_for_sim.get_tensor_shape(finnonnx_out_tensor_name)\n",
    "print(\"Input tensor shape: %s\" % str(finnonnx_model_in_shape))\n",
    "print(\"Output tensor shape: %s\" % str(finnonnx_model_out_shape))\n",
    "finnonnx_model_in_dt = model_for_sim.get_tensor_datatype(finnonnx_in_tensor_name)\n",
    "finnonnx_model_out_dt = model_for_sim.get_tensor_datatype(finnonnx_out_tensor_name)\n",
    "print(\"Input tensor datatype: %s\" % str(finnonnx_model_in_dt.name))\n",
    "print(\"Output tensor datatype: %s\" % str(finnonnx_model_out_dt.name))\n",
    "print(\"List of node operator types in the graph: \")\n",
    "print([x.op_type for x in model_for_sim.graph.node])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the output tensor is still annotated as float32, even though we know the output is binary. The correct datatype will be inferred automatically in the next step, when we run the `InferDataTypes` transformation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Network preparation: Tidy-up transformations <a id=\"network_preparations\"></a>\n",
    "\n",
    "Before running the verification, we need to prepare our FINN-ONNX model. In particular, all the intermediate tensors need to have statically defined shapes. To do this, we apply a few \"tidy-up\" graph transformations that make the model easier to process.\n",
    "\n",
    "**Graph transformations in FINN.** The whole FINN compiler is built around the idea of transformations, which gradually transform the model into a synthesizable hardware description. Although FINN offers functionality that automatically calls a standard sequence of transformations (covered in the next notebook), you can also call individual transformations manually (as we do here) and add your own transformations to create custom flows. You can read more about these transformations in [this notebook](../bnn-pynq/tfc_end2end_example.ipynb)."
   ]
  },
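  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a transformation more concrete, here is a minimal, dependency-free sketch. It does not use the FINN or QONNX API: the graph is just a list of operator-name strings, and `remove_one_identity` and `apply_until_stable` are hypothetical names invented for this illustration. The only thing it tries to mirror is the general pattern in which a transformation returns the (possibly modified) model together with a flag saying whether anything changed, so that the caller can re-apply it until the graph is stable.\n",
    "\n",
    "```python\n",
    "def remove_one_identity(ops):\n",
    "    # toy transformation: drop the first 'Identity' op, if there is one\n",
    "    for i, op in enumerate(ops):\n",
    "        if op == 'Identity':\n",
    "            return ops[:i] + ops[i + 1:], True\n",
    "    # nothing to remove, report that the graph is unchanged\n",
    "    return ops, False\n",
    "\n",
    "def apply_until_stable(ops, transformation):\n",
    "    # re-apply the transformation until it reports no further change\n",
    "    changed = True\n",
    "    while changed:\n",
    "        ops, changed = transformation(ops)\n",
    "    return ops\n",
    "\n",
    "toy_graph = ['MatMul', 'Identity', 'MultiThreshold', 'Identity', 'MatMul']\n",
    "tidy_graph = apply_until_stable(toy_graph, remove_one_identity)\n",
    "print(tidy_graph)\n",
    "```\n",
    "\n",
    "Each pass removes at most one node, so the loop is what guarantees that all `Identity` entries are gone in the end. The original list is left untouched because every pass builds a new list."
   ]
  },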
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs\n",
    "from qonnx.transformation.infer_shapes import InferShapes\n",
    "from qonnx.transformation.infer_datatypes import InferDataTypes\n",
    "from qonnx.transformation.fold_constants import FoldConstants\n",
    "\n",
    "model_for_sim = model_for_sim.transform(InferShapes())\n",
    "model_for_sim = model_for_sim.transform(FoldConstants())\n",
    "model_for_sim = model_for_sim.transform(GiveUniqueNodeNames())\n",
    "model_for_sim = model_for_sim.transform(GiveReadableTensorNames())\n",
    "model_for_sim = model_for_sim.transform(InferDataTypes())\n",
    "model_for_sim = model_for_sim.transform(RemoveStaticGraphInputs())\n",
    "\n",
    "verif_model_filename = model_dir + \"/cybsec-mlp-verification.onnx\"\n",
    "model_for_sim.save(verif_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Would the FINN compiler still work if we didn't do this?** The compilation step in the next notebook applies these transformations internally and would work fine, but we're going to use FINN's verification capabilities below and these require the tidy-up transformations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's view our ready-to-go model after the transformations. Note that all intermediate tensors now have their shapes specified (indicated by numbers next to the arrows going between layers). Additionally, the datatype inference step has propagated quantization annotations to the outputs of `MultiThreshold` layers (expand by clicking the + next to the name of the tensor to see the quantization annotation) and the final output tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.util.visualization import showInNetron\n",
    "\n",
    "showInNetron(verif_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Load the Dataset and the Brevitas Model <a id=\"load_dataset\"></a>\n",
    "\n",
    "We'll use some example data from the quantized UNSW-NB15 dataset (from the previous notebook) as inputs for the verification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "def get_preqnt_dataset(data_dir: str, train: bool):\n",
    "    unsw_nb15_data = np.load(data_dir + \"/unsw_nb15_binarized.npz\")\n",
    "    if train:\n",
    "        partition = \"train\"\n",
    "    else:\n",
    "        partition = \"test\"\n",
    "    part_data = unsw_nb15_data[partition].astype(np.float32)\n",
    "    part_data = torch.from_numpy(part_data)\n",
    "    part_data_in = part_data[:, :-1]\n",
    "    part_data_out = part_data[:, -1]\n",
    "    return TensorDataset(part_data_in, part_data_out)\n",
    "\n",
    "n_verification_inputs = 100\n",
    "test_quantized_dataset = get_preqnt_dataset(\".\", False)\n",
    "input_tensor = test_quantized_dataset.tensors[0][:n_verification_inputs]\n",
    "input_tensor.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's also bring up the MLP we trained in Brevitas from the previous notebook. We'll compare its outputs to what is generated by FINN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size = 593      \n",
    "hidden1 = 64      \n",
    "hidden2 = 64\n",
    "hidden3 = 64\n",
    "weight_bit_width = 2\n",
    "act_bit_width = 2\n",
    "num_classes = 1\n",
    "\n",
    "from brevitas.nn import QuantLinear, QuantReLU\n",
    "import torch.nn as nn\n",
    "\n",
    "brevitas_model = nn.Sequential(\n",
    "      QuantLinear(input_size, hidden1, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden1),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden1, hidden2, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden2),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden2, hidden3, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden3),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width)\n",
    ")\n",
    "\n",
    "# replace this with your trained network checkpoint if you're not\n",
    "# using the pretrained weights\n",
    "trained_state_dict = torch.load(model_dir + \"/state_dict.pth\")[\"models_state_dict\"][0]\n",
    "\n",
    "# Uncomment the following line if you previously chose to train the network yourself\n",
    "#trained_state_dict = torch.load(\"state_dict_self-trained.pth\")\n",
    "\n",
    "brevitas_model.load_state_dict(trained_state_dict, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_with_brevitas(current_inp):\n",
    "    brevitas_output = brevitas_model.forward(current_inp)\n",
    "    # apply sigmoid + threshold\n",
    "    brevitas_output = torch.sigmoid(brevitas_output)\n",
    "    brevitas_output = (brevitas_output.detach().numpy() > 0.5) * 1\n",
    "    # convert output to bipolar\n",
    "    brevitas_output = 2*brevitas_output - 1\n",
    "    return brevitas_output"
   ]
  },
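  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The post-processing in `inference_with_brevitas` can be checked on a few toy numbers. The sketch below uses plain NumPy instead of PyTorch, and the logit values and the helper name `logits_to_bipolar` are made up for illustration. It shows how the sigmoid followed by a 0.5 threshold produces a {0, 1} decision, and how `2*x - 1` maps that decision to bipolar {-1, +1}.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def logits_to_bipolar(logits):\n",
    "    # sigmoid followed by a 0.5 threshold gives a {0, 1} decision\n",
    "    probs = 1.0 / (1.0 + np.exp(-logits))\n",
    "    binary = (probs > 0.5) * 1\n",
    "    # map {0, 1} to bipolar {-1, +1}\n",
    "    return 2 * binary - 1\n",
    "\n",
    "# made-up raw network outputs, one per row\n",
    "toy_logits = np.array([[-2.0], [0.3], [4.0]])\n",
    "bipolar = logits_to_bipolar(toy_logits)\n",
    "print(bipolar.ravel())\n",
    "\n",
    "# mathematically, sigmoid(x) > 0.5 holds when x > 0,\n",
    "# so a sign check on the raw output gives the same decision here\n",
    "same_as_sign = np.array_equal(bipolar, np.where(toy_logits > 0, 1, -1))\n",
    "print(same_as_sign)\n",
    "```\n",
    "\n",
    "The bipolar encoding matters because the FINN-ONNX model produces a bipolar output, so both sides of the comparison need to use the same encoding."
   ]
  },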
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Compare FINN & Brevitas execution <a id=\"compare_brevitas\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make helper functions to execute the same input with Brevitas and FINN. For FINN, we'll use the `execute_onnx` function from [`finn.core.onnx_exec`](https://finn.readthedocs.io/en/latest/source_code/finn.core.html#finn.core.onnx_exec.execute_onnx) to execute the exported FINN-ONNX model on the inputs. Note that this ONNX execution is for verification only, not for accelerated execution.\n",
    "\n",
    "Recall that the quantized values from the dataset are 593-bit binary {0, 1} vectors whereas our exported model takes 600-bit bipolar {-1, +1} vectors, so we have to pad each input with 7 zeros and rescale it to bipolar before we can use it for verifying the ONNX model."
   ]
  },
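  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of this preprocessing, the sketch below pads a toy binary vector and rescales it to bipolar using plain NumPy. The 3-bit input and the padding width of 2 are made-up stand-ins for the 593-bit input and the 7 padding bits used in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# toy stand-in for one 593-bit binary input (batch of 1, 3 features)\n",
    "toy_inp = np.array([[0.0, 1.0, 1.0]], dtype=np.float32)\n",
    "# pad the feature dimension on the right with zeros, leave the batch dimension alone\n",
    "padded = np.pad(toy_inp, [(0, 0), (0, 2)])\n",
    "# rescale {0, 1} to bipolar {-1, +1}\n",
    "bipolar = 2 * padded - 1\n",
    "print(padded.shape)\n",
    "print(bipolar)\n",
    "```\n",
    "\n",
    "Because the padding is added before the rescaling, the padded zeros end up as -1 in the bipolar input."
   ]
  },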
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import finn.core.onnx_exec as oxe\n",
    "\n",
    "def inference_with_finn_onnx(current_inp):\n",
    "    finnonnx_in_tensor_name = model_for_sim.graph.input[0].name\n",
    "    finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)\n",
    "    finnonnx_out_tensor_name = model_for_sim.graph.output[0].name\n",
    "    # convert input to numpy for FINN\n",
    "    current_inp = current_inp.detach().numpy()\n",
    "    # add padding and re-scale to bipolar\n",
    "    current_inp = np.pad(current_inp, [(0, 0), (0, 7)])\n",
    "    current_inp = 2*current_inp-1\n",
    "    # reshape to expected input (add 1 for batch dimension)\n",
    "    current_inp = current_inp.reshape(finnonnx_model_in_shape)\n",
    "    # create the input dictionary\n",
    "    input_dict = {finnonnx_in_tensor_name : current_inp} \n",
    "    # run with FINN's execute_onnx\n",
    "    output_dict = oxe.execute_onnx(model_for_sim, input_dict)\n",
    "    # get the output tensor\n",
    "    finn_output = output_dict[finnonnx_out_tensor_name]\n",
    "    return finn_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can call our inference helper functions for each input and compare the outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import trange\n",
    "\n",
    "verify_range = trange(n_verification_inputs, desc=\"FINN execution\", position=0, leave=True)\n",
    "brevitas_model.eval()\n",
    "\n",
    "ok = 0\n",
    "nok = 0\n",
    "\n",
    "for i in verify_range:\n",
    "    # run in Brevitas with PyTorch tensor\n",
    "    current_inp = input_tensor[i].reshape((1, 593))\n",
    "    brevitas_output = inference_with_brevitas(current_inp)\n",
    "    finn_output = inference_with_finn_onnx(current_inp)\n",
    "    # compare the outputs\n",
    "    ok += 1 if finn_output == brevitas_output else 0\n",
    "    nok += 1 if finn_output != brevitas_output else 0\n",
    "    verify_range.set_description(\"ok %d nok %d\" % (ok, nok))\n",
    "    verify_range.refresh()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fail loudly unless every FINN-ONNX output matched the Brevitas output\n",
    "assert ok == n_verification_inputs, \"Verification failed. Brevitas and FINN-ONNX execution outputs are NOT identical\"\n",
    "print(\"Verification succeeded. Brevitas and FINN-ONNX execution outputs are identical\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This concludes our second notebook. In the next one, we'll take the ONNX model we just verified all the way down to FPGA hardware with the FINN compiler."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
